import tensorflow as tf
from tensorflow.keras import Sequential, layers

from dataset.captcha.captcha import load_captcha

# Load the captcha dataset; the network below is built for (24, 72, 3) inputs.
train_split, test_split = load_captcha()
x_train, t_train = train_split
x_test, t_test = test_split

# load_captcha() yields only a train and a test pair, so the validation
# names simply alias the test data.
x_validation = x_test
t_validation = t_test

# Hyperparameters
# NOTE(review): 1e-1 is far above Adam's documented default of 1e-3 —
# confirm this learning rate is intentional.
batch_size = 128
epochs = 100
learning_rate = 1e-1

# Convolutional classifier. Shape comments assume the (24, 72, 3) input given
# to build() below and Keras' default 'valid' padding.
network = Sequential([
    # (24, 72, 3) -> (22, 70, 12)
    layers.Conv2D(filters=12, kernel_size=3, strides=1,
                  activation=tf.nn.leaky_relu),
    # Default 2x2 pooling window: (22, 70, 12) -> (11, 35, 12)
    layers.MaxPooling2D(strides=2),

    layers.BatchNormalization(),
    # (11, 35, 12) -> (3, 11, 36)
    layers.Conv2D(filters=36, kernel_size=3, strides=3,
                  activation=tf.nn.leaky_relu),

    layers.BatchNormalization(),
    # (3, 11, 36) -> (1, 4, 128)
    layers.Conv2D(filters=128, kernel_size=(3, 5), strides=(1, 2),
                  activation=tf.nn.leaky_relu),

    # (1, 4, 128) -> 512 features
    layers.Flatten(),
    layers.BatchNormalization(),
    # NOTE(review): neither Dense layer has an activation, so the head is
    # linear apart from batch normalisation — confirm this is intended.
    layers.Dense(units=128 * 2),
    layers.BatchNormalization(),
    layers.Dense(units=4 * 36),
    # 144 raw outputs arranged as 4 rows of 36 values each.
    layers.Reshape(target_shape=(4, 36)),
])

# Create the weights for a batch of 24x72 three-channel inputs and print
# the layer table.
network.build(input_shape=(None, 24, 72, 3))
network.summary()


def loss_func(y_true, y_pred):
    """Return the mean squared error between targets and predictions.

    ``tf.losses.MSE`` averages the squared error over the last axis; the
    result is then averaged over all remaining axes to give one scalar.

    Args:
        y_true: Target tensor. Presumably shaped like the model output,
            (batch, 4, 36) — TODO confirm against ``load_captcha``.
        y_pred: Tensor produced by the model.

    Returns:
        Scalar tensor holding the mean squared error.
    """
    mse_over_last_axis = tf.losses.MSE(y_true, y_pred)
    return tf.reduce_mean(mse_over_last_axis)


# Adam optimizer driven by the module-level learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
network.compile(optimizer=optimizer, loss=loss_func, metrics=['accuracy'])

# x_validation / t_validation alias the test split, so the per-epoch
# validation metrics are measured on the test data.
network.fit(
    x=x_train,
    y=t_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_validation, t_validation),
)

# Final evaluation on the test split, then save the model; the .h5 suffix
# selects Keras' HDF5 format.
network.evaluate(x=x_test, y=t_test)
network.save('model.h5')
